/******************************************************************************
 * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
 * qtfs licensed under the Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 * http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
 * PURPOSE.
 * See the Mulan PSL v2 for more details.
 * Author: Liqiang
 * Create: 2023-11-23
 * Description: socket api in user-mode
 *******************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <unistd.h>
#include <signal.h>
#include <sys/epoll.h>
#include <netinet/ip.h>
#include <netinet/in.h>
#include <sys/un.h>
#include <netinet/udp.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <sys/ioctl.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <linux/vm_sockets.h>

#include "log.h"
#include "libsocket.h"

/*
 * Parameters describing one socket to be created by libsock_build_connection().
 */
struct lib_sock_arg {
	int cs;                         /* role: LIBSOCK_CLIENT or LIBSOCK_SERVER */
	int sfamily;                    /* address family: AF_VSOCK, AF_INET or AF_UNIX */
	int stype;                      /* socket type: SOCK_STREAM or SOCK_DGRAM */

	struct sockaddr_storage saddr;  /* peer address (client) or local bind address (server), laid out per sfamily */
};

static inline int check_sock_arg(struct lib_sock_arg *arg)
{
	if (arg->cs != LIBSOCK_CLIENT && arg->cs != LIBSOCK_SERVER) {
		log_err("build new connection role invalid(%d) must be CLIENT(%d) or SERVER(%d)", arg->cs, LIBSOCK_CLIENT, LIBSOCK_SERVER);
		return -1;
	}
	if (arg->sfamily != AF_VSOCK && arg->sfamily != AF_INET && arg->sfamily != AF_UNIX) {
		log_err("build new connection family invalid(%d), just support AF_UNIX(%d)/AF_INET(%d)/AF_VSOCK(%d).",
			arg->sfamily, AF_UNIX, AF_INET, AF_VSOCK);
		return -1;
	}
	if (arg->stype != SOCK_DGRAM && arg->stype != SOCK_STREAM) {
		log_err("build new connection type invalid(%d), just support SOCK_DGRAM or SOCK_STREAM.",
			arg->stype, SOCK_DGRAM, SOCK_STREAM);
		return -1;
	}
	return 0;
}

/*
 * Map an address family to the size of its concrete sockaddr structure.
 *
 * Returns sizeof(struct sockaddr_vm), sizeof(struct sockaddr_in) or
 * sizeof(struct sockaddr_un) for AF_VSOCK, AF_INET and AF_UNIX respectively;
 * any other family is logged and yields -1.
 */
static inline int get_sock_len(int family)
{
	if (family == AF_VSOCK) {
		return sizeof(struct sockaddr_vm);
	}
	if (family == AF_INET) {
		return sizeof(struct sockaddr_in);
	}
	if (family == AF_UNIX) {
		return sizeof(struct sockaddr_un);
	}

	log_err("invalid family:%d", family);
	return -1;
}

/* Backlog passed to listen() for connection-oriented server sockets. */
#define MAX_LISTEN_NUM 64

/*
 * Create a socket described by @arg and put it into its role.
 *
 * Server: bind() to arg->saddr, then listen() when the type is SOCK_STREAM
 * (listen() is not supported on datagram sockets and would fail).
 * Client: connect() to arg->saddr.
 *
 * Returns the socket fd on success. On failure it returns -1, closes any
 * socket it created, and leaves errno set by the call that failed.
 */
static int libsock_build_connection(struct lib_sock_arg *arg)
{
	int ret;
	int sockfd;
	int socklen;
	int saved_errno;

	if (check_sock_arg(arg) != 0) {
		log_err("Arg error, please check!");
		return -1;
	}
	/* check_sock_arg() already accepted the family, so this cannot be -1. */
	socklen = get_sock_len(arg->sfamily);

	sockfd = socket(arg->sfamily, arg->stype, 0);
	if (sockfd < 0) {
		saved_errno = errno;
		log_err("As %s failed, socket fd:%d, errno:%d.",
			(arg->cs == LIBSOCK_CLIENT) ? "client" : "server", sockfd, saved_errno);
		errno = saved_errno;
		return -1;
	}

	if (arg->cs == LIBSOCK_SERVER) {
		if ((ret = bind(sockfd, (struct sockaddr *)&arg->saddr, (socklen_t)socklen)) < 0) {
			saved_errno = errno;
			log_err("As server failed socklen:%d, bind ret:%d error:%d", socklen, ret, saved_errno);
			goto err_ret;
		}
		/* Only stream sockets accept connections; listen() on SOCK_DGRAM fails with EOPNOTSUPP. */
		if (arg->stype == SOCK_STREAM && (ret = listen(sockfd, MAX_LISTEN_NUM)) < 0) {
			saved_errno = errno;
			log_err("As server listen failed ret:%d errno:%d", ret, saved_errno);
			goto err_ret;
		}
	} else {
		if ((ret = connect(sockfd, (struct sockaddr *)&arg->saddr, (socklen_t)socklen)) < 0) {
			saved_errno = errno;
			log_err("As client failed socklen:%d, connect ret:%d errno:%d", socklen, ret, saved_errno);
			goto err_ret;
		}
	}
	return sockfd;

err_ret:
	close(sockfd);
	/* close() and logging may overwrite errno; report the original failure. */
	errno = saved_errno;
	return -1;
}

/*
 * Accept one pending connection on listening socket @sockfd.
 *
 * @family is only used in the failure log; the peer address is received into
 * a sockaddr_storage and then discarded.
 *
 * Returns the connected fd (which may legitimately be 0) on success. On
 * failure it returns -1 with errno preserved from accept().
 */
int libsock_accept(int sockfd, int family)
{
	struct sockaddr_storage saddr;
	/* Size from the buffer itself: an unknown family must not produce a bogus length. */
	socklen_t len = sizeof(saddr);
	int connfd = accept(sockfd, (struct sockaddr *)&saddr, &len);

	if (connfd < 0) {
		int saved_errno = errno;

		log_err("Accept failed sockfd:%d family:%d ret:%d errno:%d", sockfd, family, connfd, saved_errno);
		errno = saved_errno;
		return -1;
	}
	return connfd;
}

/*
 * Create an AF_INET SOCK_STREAM socket for @ip:@port.
 *
 * @ip:   dotted-decimal IPv4 address; it is the local bind address when @cs is
 *        the server role and the peer address when @cs is the client role.
 * @port: port number in host byte order.
 * @cs:   LIBSOCK_CLIENT or LIBSOCK_SERVER.
 *
 * Returns the socket fd on success. Returns -1 if @ip is NULL or not a valid
 * IPv4 address, or if socket creation, bind/listen or connect fails.
 */
int libsock_build_inet_connection(char *ip, unsigned short port, enum libsock_cs_e cs)
{
	struct lib_sock_arg arg;
	struct sockaddr_in *in = (struct sockaddr_in *)&arg.saddr;
	int sockfd;

	if (ip == NULL) {
		log_err("build inet connection failed, ip is NULL");
		return -1;
	}

	memset(&arg, 0, sizeof(struct lib_sock_arg));
	in->sin_family = AF_INET;
	in->sin_port = htons(port);
	/*
	 * inet_addr() reports a parse error as INADDR_NONE, which is also the valid
	 * address 255.255.255.255, so a typo would silently target broadcast.
	 * inet_pton() reports the error distinctly.
	 */
	if (inet_pton(AF_INET, ip, &in->sin_addr) != 1) {
		log_err("build inet connection failed, invalid ipv4 address:%s", ip);
		return -1;
	}
	arg.cs = cs;
	arg.sfamily = AF_INET;
	arg.stype = SOCK_STREAM;

	sockfd = libsock_build_connection(&arg);
	if (sockfd < 0) {
		log_err("build inet connection failed, ip:%s port:%u", ip, (unsigned int)port);
		return -1;
	}
	return sockfd;
}

/*
 * Create an AF_VSOCK SOCK_STREAM socket addressed by context id @cid and @port.
 *
 * @cs selects the role passed to libsock_build_connection(): the server role
 * binds/listens on the address, and the client role connects to it.
 *
 * Returns the socket fd on success, -1 (after logging cid/port) on failure.
 */
int libsock_build_vsock_connection(unsigned int cid, unsigned int port, enum libsock_cs_e cs)
{
	struct lib_sock_arg arg;
	struct sockaddr_vm *addr = (struct sockaddr_vm *)&arg.saddr;
	int fd;

	memset(&arg, 0, sizeof(arg));

	/* Describe the socket first, then fill in the vsock address inside saddr. */
	arg.cs = cs;
	arg.sfamily = AF_VSOCK;
	arg.stype = SOCK_STREAM;
	addr->svm_family = AF_VSOCK;
	addr->svm_cid = cid;
	addr->svm_port = port;

	fd = libsock_build_connection(&arg);
	if (fd >= 0) {
		return fd;
	}

	log_err("build vsock connection failed, cid:%u port:%u", cid, port);
	return -1;
}

